// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//go:build ignore
// +build ignore

package main

import (
	"bufio"
	"bytes"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"strconv"
	"text/template"
)

// Inclusive range of macro values accepted as audit record (message) type
// numbers; AUDIT_* macros whose value falls outside it are ignored by
// readRecordTypes.
const (
	minRecordNum = 1000 // smallest accepted record number
	maxRecordNum = 3000 // largest accepted record number
)

// TemplateParams holds the values that fileTemplate is rendered with.
type TemplateParams struct {
	// Command is the generator name written into the "Code generated by"
	// header of the output file.
	Command string

	// RecordTypes maps each record type number to the name of the Go
	// constant that is declared for it.
	RecordTypes map[int]string

	// RecordNames maps each Go constant name to the record type's string
	// name, as emitted into the generated name lookup tables.
	RecordNames map[string]string
}

// fileTemplate is the text/template source of the generated auparse file.
// It is executed with a TemplateParams value: RecordTypes drives the list of
// AuditMessageType constants (number -> constant name) and RecordNames drives
// the constant <-> string name lookup tables. text/template ranges over maps
// in sorted key order, so the emitted declarations are ordered by number and
// by constant name respectively.
const fileTemplate = `
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Code generated by {{.Command}} - DO NOT EDIT.

package auparse

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

var errInvalidAuditMessageTypName = errors.New("invalid message type")

// AuditMessageType represents an audit message type used by the kernel.
type AuditMessageType uint16

// List of AuditMessageTypes.
const(
{{- range $recordNum, $recordType := .RecordTypes }}
	{{ $recordType}} AuditMessageType = {{ $recordNum }}
{{- end }}
)

var auditMessageTypeToName = map[AuditMessageType]string{
{{- range $recordType, $recordName := .RecordNames }}
	{{ $recordType }}: "{{ $recordName }}",
{{- end }}
}

// String returns the name of the message type, or UNKNOWN[<number>] when the
// type has no known name.
func (t AuditMessageType) String() string {
	name, found := auditMessageTypeToName[t]
	if found {
		return name
	}

	return fmt.Sprintf("UNKNOWN[%d]", uint16(t))
}

// MarshalText implements encoding.TextMarshaler. It encodes the message type
// as its lower-cased name.
func (t AuditMessageType) MarshalText() (text []byte, err error) {
	return []byte(strings.ToLower(t.String())), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. It accepts any name
// understood by GetAuditMessageType.
func (t *AuditMessageType) UnmarshalText(text []byte) error {
	messageType, err := GetAuditMessageType(string(text))
	if err != nil {
		return err
	}
	*t = messageType
	return nil
}

var auditMessageNameToType = map[string]AuditMessageType{
{{- range $recordType, $recordName := .RecordNames }}
	"{{ $recordName }}": {{ $recordType }},
{{- end }}
}

// GetAuditMessageType accepts a type name and returns its numerical
// representation. The lookup is case-insensitive. A name without a known
// mapping is still accepted if it contains a bracketed decimal number that
// fits in 16 bits, such as UNKNOWN[1329] (the form String produces for
// unnamed types); otherwise an error is returned.
func GetAuditMessageType(name string) (AuditMessageType, error) {
	name = strings.ToUpper(name)

	typ, found := auditMessageNameToType[name]
	if found {
		return typ, nil
	}

	// Parse type from UNKNOWN[1329].
	start := strings.IndexByte(name, '[')
	if start == -1 {
		return 0, errInvalidAuditMessageTypName
	}
	name = name[start+1:]

	end := strings.IndexByte(name, ']')
	if end == -1 {
		return 0, errInvalidAuditMessageTypName
	}
	name = name[:end]

	num, err := strconv.ParseUint(name, 10, 16)
	if err != nil {
		return 0, errInvalidAuditMessageTypName
	}

	return AuditMessageType(num), nil
}
`

// tmpl is the parsed fileTemplate. template.Must panics at program start-up
// if the template source is malformed.
var tmpl = template.Must(template.New("message_types").Parse(fileTemplate))

// auditUserspaceLibURL is the raw-content base URL of the lib/ directory of
// the linux-audit userspace project at ref v4.0.2.
const auditUserspaceLibURL = "https://raw.githubusercontent.com/linux-audit/audit-userspace/v4.0.2/lib/"

// headers lists the upstream C headers that run downloads: audit-records.h
// supplies the AUDIT_* record type numbers (parsed by readRecordTypes) and
// msg_typetab.h supplies their string names (parsed by readMessageTypeTable).
var headers = []string{
	auditUserspaceLibURL + "audit-records.h",
	auditUserspaceLibURL + "msg_typetab.h",
}

// DownloadFile fetches url with an HTTP GET and stores the response body in
// destinationDir under the last path element of url. It returns the path of
// the written file.
//
// A request failure, a non-200 response, or any failure creating, writing or
// closing the destination file is returned as an error.
func DownloadFile(url, destinationDir string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", fmt.Errorf("http get failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("download failed with http status %v", resp.StatusCode)
	}

	name := filepath.Join(destinationDir, filepath.Base(url))
	f, err := os.Create(name)
	if err != nil {
		return "", fmt.Errorf("failed to create output file: %w", err)
	}

	if _, err := io.Copy(f, resp.Body); err != nil {
		// Best effort: the copy error is the one worth reporting.
		f.Close()
		return "", fmt.Errorf("failed to write file to disk: %w", err)
	}

	// Close explicitly rather than deferring so that an error reported by
	// Close is not silently dropped, and so the descriptor is released.
	if err := f.Close(); err != nil {
		return "", fmt.Errorf("failed to write file to disk: %w", err)
	}

	return name, nil
}

// Patterns used to pull record type information out of the upstream C
// headers.
var (
	// recordTypeDefinitionRegex matches a "#define AUDIT_<NAME> <decimal>"
	// line of preprocessed header output, capturing the macro name and its
	// decimal value.
	recordTypeDefinitionRegex = regexp.MustCompile(`^#define\s+(AUDIT_\w+)\s+(\d+)`)

	// nameMappingRegex matches an `_S(AUDIT_<NAME>, "<string name>"` entry at
	// the start of a msg_typetab.h line, capturing the constant and its
	// string name.
	nameMappingRegex = regexp.MustCompile(`^_S\((AUDIT_\w+),\s+"(\w+)"`)

	// nameExtractorRegex captures whatever follows the AUDIT_ prefix of a
	// constant; it supplies a string name for constants that have no entry
	// in msg_typetab.h.
	nameExtractorRegex = regexp.MustCompile(`^AUDIT_(\w+)`)
)

// readMessageTypeTable parses msg_typetab.h from the current working
// directory. For every line matching nameMappingRegex, i.e. one beginning
// with _S(AUDIT_<NAME>, "<string name>", it records constant -> string name.
// Non-matching lines are skipped. The entries read so far are returned
// together with any error reported by the scanner.
func readMessageTypeTable() (map[string]string, error) {
	table, err := os.Open("msg_typetab.h")
	if err != nil {
		return nil, err
	}
	defer table.Close()

	names := make(map[string]string)
	lines := bufio.NewScanner(table)
	for lines.Scan() {
		groups := nameMappingRegex.FindStringSubmatch(lines.Text())
		if len(groups) != 3 {
			continue
		}
		constant, stringName := groups[1], groups[2]
		names[constant] = stringName
	}

	return names, lines.Err()
}

// readRecordTypes runs gcc's preprocessor over audit-records.h in the current
// working directory and collects every "#define AUDIT_<NAME> <decimal>" macro
// whose value lies in [minRecordNum, maxRecordNum]. It returns a map from
// macro name to value, plus any error reported by the line scanner.
// gcc must be available on PATH.
func readRecordTypes() (map[string]int, error) {
	// -E stops after preprocessing; -dD additionally emits the #define
	// directives that were processed, which is what gets parsed below.
	out, err := exec.Command("gcc", "-E", "-dD", "audit-records.h").Output()
	if err != nil {
		// Output returns *exec.ExitError directly and, because Cmd.Stderr is
		// nil, fills in its Stderr field; include gcc's diagnostics so the
		// failure is actionable.
		if exitErr, ok := err.(*exec.ExitError); ok && len(exitErr.Stderr) > 0 {
			return nil, fmt.Errorf("failed to run gcc: %w: %s", err, bytes.TrimSpace(exitErr.Stderr))
		}
		return nil, fmt.Errorf("failed to run gcc: %w", err)
	}

	recordTypeToNum := map[string]int{}
	s := bufio.NewScanner(bytes.NewReader(out))
	for s.Scan() {
		matches := recordTypeDefinitionRegex.FindStringSubmatch(s.Text())
		if len(matches) != 3 {
			continue
		}
		recordNum, err := strconv.Atoi(matches[2])
		if err != nil {
			// The regex only admits digits, so this can only be a value
			// outside the int range, which cannot be a record number.
			continue
		}

		// Keep only values inside the audit record/message number range.
		if recordNum >= minRecordNum && recordNum <= maxRecordNum {
			recordTypeToNum[matches[1]] = recordNum
		}
	}

	return recordTypeToNum, s.Err()
}

// run generates the file named by flagOut. It downloads the upstream headers
// into a temporary directory, reads record type names from msg_typetab.h and
// record type numbers from audit-records.h, merges them, and renders
// fileTemplate with the result.
//
// run changes the process working directory to the temporary directory, so
// flagOut must already be an absolute path.
func run() error {
	tmp, err := os.MkdirTemp("", "mk_audit_msg_types")
	if err != nil {
		return err
	}
	defer os.RemoveAll(tmp)

	// Download header files from the Linux audit project.
	for _, url := range headers {
		if _, err := DownloadFile(url, tmp); err != nil {
			return fmt.Errorf("download failed for %v: %w", url, err)
		}
	}

	// The readers below open the headers by bare file name.
	if err := os.Chdir(tmp); err != nil {
		return err
	}

	recordTypeToStringName, err := readMessageTypeTable()
	if err != nil {
		return err
	}

	recordTypeToNum, err := readRecordTypes()
	if err != nil {
		return err
	}

	numToRecordType, err := mapRecordNumbers(recordTypeToStringName, recordTypeToNum)
	if err != nil {
		return err
	}

	// Render into memory first so that a template error never leaves a
	// truncated output file behind, and so that write errors (including the
	// one reported on close) are returned instead of dropped.
	var buf bytes.Buffer
	params := TemplateParams{
		Command:     filepath.Base(os.Args[0]),
		RecordTypes: numToRecordType,
		RecordNames: recordTypeToStringName,
	}
	if err := tmpl.Execute(&buf, params); err != nil {
		return err
	}

	return os.WriteFile(flagOut, buf.Bytes(), 0o666)
}

// mapRecordNumbers builds the record number -> constant name table used for
// the generated constants.
//
// Every constant in names (read from msg_typetab.h) must have a number in
// nums, and no two of them may share a number; both conditions are reported
// as errors because the generated file would otherwise reference constants
// it never declares. Constants present only in nums are then added for
// numbers not already taken. When several of them compete for the same free
// number, the lexicographically smallest name wins so the output is
// reproducible. Each constant added this way also gets a string name derived
// from its AUDIT_ suffix, which is written into names.
func mapRecordNumbers(names map[string]string, nums map[string]int) (map[int]string, error) {
	numToRecordType := make(map[int]string, len(nums))
	for recordType := range names {
		num, found := nums[recordType]
		if !found {
			return nil, fmt.Errorf("missing definition of %v", recordType)
		}
		if other, dup := numToRecordType[num]; dup {
			return nil, fmt.Errorf("%v and %v both have record number %d", other, recordType, num)
		}
		numToRecordType[num] = recordType
	}

	// Map iteration order is random, so choose the winner for each free
	// number explicitly instead of letting whichever is visited first win.
	aliases := map[int]string{}
	for recordType, num := range nums {
		if _, taken := numToRecordType[num]; taken {
			continue // Do not replace mappings from msg_typetab.h.
		}
		if current, ok := aliases[num]; !ok || recordType < current {
			aliases[num] = recordType
		}
	}

	for num, recordType := range aliases {
		numToRecordType[num] = recordType
		if matches := nameExtractorRegex.FindStringSubmatch(recordType); len(matches) == 2 {
			names[recordType] = matches[1]
		}
	}

	return numToRecordType, nil
}

// flagOut is the path of the generated Go file, taken from the -out flag.
// main turns it into an absolute path before calling run, because run
// changes the working directory.
var flagOut string

// main parses the command line and runs the generator. Any error is printed
// to stderr as "error: <message>" and the process exits with status 1.
func main() {
	flag.StringVar(&flagOut, "out", "zaudit_msg_types.go", "output file")
	flag.Parse()

	exitOnError := func(err error) {
		if err == nil {
			return
		}
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
		os.Exit(1)
	}

	absOut, err := filepath.Abs(flagOut)
	exitOnError(err)
	flagOut = absOut

	exitOnError(run())
}
